
[PyTorch] Zero-initialize learnable softmax_offset in DotProductAttention#2694

Merged
ksivaman merged 2 commits into NVIDIA:main from fjosw:fix/softmax-offset-zero-init-v2
Mar 8, 2026

Conversation

@fjosw
Contributor

@fjosw fjosw commented Feb 20, 2026

Description

The PyTorch implementation of DotProductAttention initializes the learnable softmax_offset parameter with torch.empty(), which leaves it containing uninitialized memory. Unlike all other TransformerEngineBaseModule subclasses (Linear, LayerNormLinear, LayerNormMLP, GroupedLinear), DotProductAttention does not call self.reset_parameters() in its __init__, so the deferred initialization system that would normally overwrite the torch.empty() contents is never invoked. The JAX implementation explicitly uses nn.initializers.zeros for this parameter. This fix aligns the PyTorch behavior by using torch.zeros() directly. In Megatron-LM this is not a problem because the parameter is initialized explicitly, but when DotProductAttention is used in isolation the stale memory can lead to nondeterministic results.
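To illustrate the deferred-initialization pattern described above, here is a minimal toy sketch (class and attribute names are hypothetical, not TransformerEngine's actual code): modules that call reset_parameters() in __init__ overwrite the torch.empty() storage, while a module that skips the call keeps whatever stale memory the allocator returned.

```python
import torch
from torch import nn

class DeferredInitModule(nn.Module):
    """Toy illustration of the pattern; names are hypothetical."""
    def __init__(self, call_reset: bool):
        super().__init__()
        # torch.empty() allocates storage without initializing it
        self.offset = nn.Parameter(torch.empty(4))
        if call_reset:
            # Linear, LayerNormLinear, etc. invoke their reset in __init__
            self.reset_parameters()

    def reset_parameters(self):
        with torch.no_grad():
            self.offset.zero_()

safe = DeferredInitModule(call_reset=True)
assert safe.offset.detach().eq(0).all()
# With call_reset=False, self.offset may hold arbitrary stale values.
```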

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

  • Change torch.empty() to torch.zeros() when creating the learnable softmax_offset parameter in DotProductAttention, ensuring it is zero-initialized rather than containing uninitialized memory.
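In sketch form, the corrected initialization across the three softmax_type values discussed in this PR looks like the following (a standalone sketch with hypothetical function and argument names, not the actual DotProductAttention source):

```python
import torch
from torch.nn import Parameter

def make_softmax_offset(softmax_type: str, num_heads: int = 16):
    # Mirrors the three branches described in the review, not the real code
    if softmax_type == "vanilla":
        return None  # no offset applied
    if softmax_type == "off-by-one":
        return torch.zeros(num_heads)  # fixed, non-learnable offset
    if softmax_type == "learnable":
        # previously Parameter(torch.empty(...)): uninitialized memory
        return Parameter(torch.zeros(num_heads))
    raise ValueError(f"unknown softmax_type: {softmax_type}")

assert make_softmax_offset("vanilla") is None
assert make_softmax_offset("learnable").detach().eq(0).all()
```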

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

…tion

DotProductAttention used torch.empty() for the learnable softmax_offset
parameter. Unlike all other TransformerEngineBaseModule subclasses,
DotProductAttention does not call reset_parameters() in __init__, so the
deferred initialization that would normally overwrite the empty tensor is
never invoked, leaving the parameter with uninitialized memory.

The JAX implementation explicitly uses nn.initializers.zeros for this
parameter. This aligns the PyTorch behavior by using torch.zeros().

Signed-off-by: Fabian Joswig <fjosw@users.noreply.github.com>
@greptile-apps
Contributor

greptile-apps bot commented Feb 20, 2026

Greptile Summary

This PR fixes a bug in DotProductAttention where the learnable softmax_offset parameter was initialized with torch.empty(), leaving it containing uninitialized memory. The fix replaces it with torch.zeros(), ensuring deterministic zero-initialization consistent with the off-by-one branch, the JAX implementation, and how Megatron-LM initializes the parameter externally.

Key changes:

  • transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py: Replace torch.empty(...) with torch.zeros(...) for the learnable softmax_offset Parameter (line 442)

Observations:

  • The fix is correct and minimal. All three softmax_type branches now use a well-defined initial value (None, torch.zeros, torch.zeros), eliminating the uninitialized-memory hazard.
  • The PR author notes that no regression test was added to verify the zero-initialization behavior — adding a simple test (e.g., asserting param.data.eq(0).all() after construction with softmax_type="learnable") would help prevent future regressions.
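The regression test suggested above could be sketched roughly as follows. Since constructing a real DotProductAttention requires the TransformerEngine API, this standalone sketch checks the same property on a stand-in module; the helper name and stand-in are hypothetical.

```python
import torch

def assert_zero_initialized(module: torch.nn.Module, name: str) -> None:
    """Hypothetical helper: verify a named parameter is exactly zero
    immediately after construction."""
    param = dict(module.named_parameters())[name]
    assert param.detach().eq(0).all(), f"{name} is not zero-initialized"

# Stand-in for DotProductAttention(softmax_type="learnable")
m = torch.nn.Module()
m.register_parameter("softmax_offset", torch.nn.Parameter(torch.zeros(8)))
assert_zero_initialized(m, "softmax_offset")
```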

Confidence Score: 5/5

  • This PR is safe to merge — it is a minimal, clearly correct one-line bug fix with no behavioral risk.
  • The change is a single-line swap from torch.empty to torch.zeros, which is unambiguously correct: it aligns with the sibling off-by-one branch, the JAX reference implementation, and the stated intent of zero-initializing the learnable offset. There are no side effects, no API changes, and no risk of regression beyond the absence of a new explicit test.
  • No files require special attention.

Important Files Changed

Filename: transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py
Overview: Single-line fix replacing torch.empty() with torch.zeros() for the learnable softmax_offset parameter, ensuring zero-initialization instead of uninitialized memory; change is correct and consistent with the off-by-one branch and JAX implementation.

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A[DotProductAttention.__init__] --> B{softmax_type?}
    B -- vanilla --> C[softmax_offset = None]
    B -- off-by-one --> D["softmax_offset = torch.zeros(...)"]
    B -- learnable --> E["register_parameter('softmax_offset', Parameter(torch.zeros(...)))"]
    E --> F[Parameter zero-initialized]
    E --> G["reset_parameters() NOT called\n(unlike Linear, LayerNormLinear, etc.)"]
    G --> H["Previously: torch.empty() → uninitialized memory ❌"]
    F --> I["Now: torch.zeros() → deterministic zero init ✅"]
    I --> J[Consistent with JAX implementation]

Last reviewed commit: 1ced390

Contributor

@greptile-apps greptile-apps bot left a comment


1 file reviewed, no comments


@ptrendx
Member

ptrendx commented Feb 24, 2026

LGTM, thank you for the fix. @cyanguwa Could you also take a look?

@ptrendx
Member

ptrendx commented Feb 24, 2026

/te-ci pytorch

@ksivaman
Member

ksivaman commented Mar 7, 2026

/te-ci pytorch


@ksivaman ksivaman merged commit ab9d60e into NVIDIA:main Mar 8, 2026
20 of 24 checks passed


3 participants